from PIL import Image, ImageDraw
import json
import os
import re
import pdb
from tqdm import tqdm
import random
import argparse
import jsonlines
import ast
import torch

import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image
import numpy as np


# Data roots. Each can be overridden with an environment variable so the
# script does not have to be edited before running; the defaults are the
# original placeholders.
imgs_dir = os.environ.get("MIND2WEB_IMGS_DIR", "Your data path/Mind2Web/images")
anno_dir = os.environ.get("MIND2WEB_ANNO_DIR", "Your data path/Mind2Web")


_MIND2WEB_SYSTEM_ADD_DESCRIPTION_THINKING = """You are an assistant trained to navigate the web. 
Given a task instruction, a screenshot, and a last history action summary, output the think and ext action and wait for the next observation. 
The think must strictly follow these reasoning steps:
(1) Progress Estimation: Interface Comprehension and Progress Estimation
(2) Decesion Reasoning: Strategy Formulation
(3) History Summary: Update the history action summary according to the last history action summary and the action you executed.

## Action Space
1. `CLICK`: Click on an element, value is the element to click and the position [x,y] is required.
2. `TYPE`: Type a string into an element, value is the string to type and the position [x,y] is required.
3. `SELECT`: Select a value for an element, value is the value to select and the position [x,y] is required.
Position represents the relative coordinates on the screenshot and should be scaled to a range of 0-1.

## Output Format
<Progress Estimation>
...
</Progress Estimation>
<Decesion Reasoning>
...
</Decesion Reasoning>
<answer>
{{'action': 'ACTION_TYPE', 'value': 'element', 'position': [x,y]}}
</answer>
<History Summary>
...
</History Summary>

If value or position is not applicable, set it as `None`.
Position represents the relative coordinates on the screenshot and should be scaled to a range of 0-1.
"""


def get_bbox(action, image_size):
    """Return the action's element box as normalized ``[x1, y1, x2, y2]``.

    ``action["bbox"]`` supplies pixel ``x``, ``y``, ``width`` and ``height``.
    The top-left and bottom-right corners are divided by ``image_size``
    (``[0]`` = width, ``[1]`` = height), and every coordinate is rounded to
    3 decimal places.
    """
    box = action["bbox"]
    img_w, img_h = image_size[0], image_size[1]

    left, top = box["x"], box["y"]
    right = left + box["width"]
    bottom = top + box["height"]

    # Pair each corner coordinate with the image dimension it is scaled by.
    corners = ((left, img_w), (top, img_h), (right, img_w), (bottom, img_h))
    return [round(coord / scale, 3) for coord, scale in corners]

def get_value(step_repr):
    """Extract the element text from an action representation string.

    Returns the (non-greedy) text between the first ``]`` followed by
    whitespace and the next whitespace-then-``->``, e.g. ``"Departure"`` from
    ``"[textbox]  Departure -> TYPE: NYC"``. The result may be an empty
    string when no text sits between the two markers, and is ``None`` when
    the pattern does not occur at all.
    """
    found = re.search(r'\]\s+(.*?)\s+->', step_repr)
    return found.group(1) if found else None

def get_answer(sample, step, step_repr):
    """Build the ground-truth action dict for one Mind2Web step.

    Args:
        sample: Per-step metadata. Only ``sample['img_size']`` is read; it is
            indexed as ``[0]`` = width and ``[1]`` = height.
        step: Action record providing ``step['operation']['op']``,
            ``step['operation']['value']`` and a pixel ``step['bbox']`` with
            ``x``, ``y``, ``width`` and ``height``.
        step_repr: Textual representation of the action; for non-TYPE
            operations the element text is parsed from it via ``get_value``.

    Returns:
        dict: ``{'action': op, 'value': element, 'position': [x, y]}``, where
        ``value`` is the typed text for TYPE and otherwise the parsed element
        text (``None`` if it could not be parsed), and ``position`` is the bbox
        center divided by the image size, rounded to 2 decimals.
    """
    image_size = sample['img_size']

    action_type = step['operation']['op']
    if action_type == 'TYPE':
        element = step['operation']['value']
    else:
        # NOTE(review): for SELECT this yields the element text rather than the
        # option being chosen, although the system prompt describes SELECT's
        # value as "the value to select" -- confirm what downstream expects.
        element = get_value(step_repr)

    bbox = step['bbox']
    center_x = bbox["x"] + bbox["width"] / 2
    center_y = bbox["y"] + bbox["height"] / 2
    click_point = [
        round(center_x / image_size[0], 2),
        round(center_y / image_size[1], 2),
    ]
    return {'action': action_type, 'value': element, 'position': click_point}

def data_transform(version='train', mini=False):
    """Convert one Mind2Web annotation split into per-step training samples.

    Reads ``{anno_dir}/mind2web_data_{version}.json``. For every action whose
    screenshot ``{imgs_dir}/{annotation_id}-{action_uid}.jpg`` exists, one
    sample dict is emitted. It holds the image path, the system prompt, the
    ground-truth answer string, the task text, the history of earlier answers
    in the same episode, the normalized reference bbox, and ``next_id``.

    Args:
        version: Split suffix of the annotation file (e.g. ``'train'``).
        mini: Unused; kept so existing callers keep working.

    Returns:
        list[dict]: Samples in episode/step order with consecutive integer
        ``id`` values starting at 0. ``next_id`` is ``id + 1`` except for the
        last sample of an episode, where it equals ``id``.
    """
    with open(f"{anno_dir}/mind2web_data_{version}.json", 'r', encoding='utf-8') as f:
        episodes = json.load(f)

    total_step = []
    step_i = 0

    for episode in tqdm(episodes):
        annot_id = episode["annotation_id"]
        confirmed_task = episode["confirmed_task"]

        # Keep only steps whose screenshot exists, so "last step" below means
        # the last step that actually produces a sample.
        available = []
        for step, step_repr in zip(episode["actions"], episode["action_reprs"]):
            filename = annot_id + '-' + step["action_uid"] + '.jpg'
            img_path = os.path.join(imgs_dir, filename)
            if os.path.exists(img_path):
                available.append((step, step_repr, filename, img_path))

        previous_actions = []
        for idx, (step, step_repr, filename, img_path) in enumerate(available):
            # Textual history of all earlier answers in this episode, e.g.
            # "Step0, previous action: {...}. Step1, previous action: {...}. "
            previous_step = "".join(
                f"Step{i}, previous action: {action}. "
                for i, action in enumerate(previous_actions)
            )

            # Only the image size is needed; the file is closed right after.
            with Image.open(img_path) as image:
                image_size = image.size

            item = {
                'img_url': filename,
                'img_size': image_size,
                'task': confirmed_task
            }
            cur_answer = str(get_answer(item, step, step_repr))
            previous_actions.append(cur_answer)

            bbox_ref = get_bbox(step, image_size)

            # The last sample of an episode points to itself; every other
            # sample points to the next one.
            is_last = idx == len(available) - 1
            next_id = step_i if is_last else step_i + 1

            total_step.append({
                # NOTE(review): name is fixed to "Mind2Web_train" even for
                # test splits -- confirm whether consumers depend on it.
                "dataset_name": "Mind2Web_train",
                "id": step_i,
                "image": img_path,
                "problem": _MIND2WEB_SYSTEM_ADD_DESCRIPTION_THINKING,
                "solution": cur_answer,
                "task": confirmed_task,
                "history": previous_step,
                "bbox_ref": bbox_ref,
                "next_id": next_id,
            })

            step_i += 1

    return total_step

if __name__ == "__main__":
    # Accumulate samples across every listed split. Reassigning inside the
    # loop would keep only the last split if more were added to the list.
    train_step = []
    for version in ['train']:
        train_step.extend(data_transform(version=version))

    save_url = "Your save path"  # placeholder: output .jsonl path
    with jsonlines.open(save_url, mode="w") as writer:
        writer.write_all(train_step)

    # Test splits (disabled): same pipeline, concatenated into one file.
    # test_full = []
    # for version in ['test_task', 'test_domain', 'test_website']:
    #     test_full.extend(data_transform(version=version))
    # save_url = "Your save path"
    # with jsonlines.open(save_url, mode="w") as writer:
    #     writer.write_all(test_full)